--- external/stylegan/training/training_loop.py	2023-04-06 03:45:49.515150152 +0000
+++ external_reference/stylegan/training/training_loop.py	2023-04-06 03:41:03.318609421 +0000
@@ -24,14 +24,17 @@
 from torch_utils.ops import grid_sample_gradfix
 
 import legacy
-from metrics import metric_main
+from external.stylegan.metrics import metric_main
+
+from utils import camera_util
+from collections import defaultdict
 
 #----------------------------------------------------------------------------
 
 def setup_snapshot_image_grid(training_set, random_seed=0):
     rnd = np.random.RandomState(random_seed)
-    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
-    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)
+    gw = np.clip(1024// training_set.image_shape[2], 7, 32)
+    gh = np.clip(1024// training_set.image_shape[1], 4, 32)
 
     # No labels => show random subset of training samples.
     if not training_set.has_labels:
@@ -61,29 +64,57 @@
             grid_indices += [indices[x % len(indices)] for x in range(gw)]
             label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
 
-    # Load data.
-    images, labels = zip(*[training_set[i] for i in grid_indices])
-    return (gw, gh), np.stack(images), np.stack(labels)
+    # Load data -- modified for rgb/depth/mask loader
+    rgbs, depths, origs, accs, Rts, labels = [], [], [], [], [], []
+    for i in grid_indices:
+        image_info, label = training_set[i]
+        # (rgb, depth, acc, K, Rt) = image_info
+        rgbs.append(image_info['rgb'])
+        depths.append(image_info['depth'])
+        origs.append(image_info['orig'])
+        accs.append(image_info['acc'])
+        Rts.append(image_info['Rt'])
+        labels.append(label)
+    rgbs = np.stack(rgbs)
+    depths = np.stack(depths)
+    origs = np.stack(origs)
+    accs = np.stack(accs)
+    Rts = np.stack(Rts)
+    # images, labels = zip(*[training_set[i] for i in grid_indices])
+    return (gw, gh), {'rgb': rgbs, 'depth': depths, 'orig': origs, 'acc': accs, 'Rt': Rts}, np.stack(labels)
 
 #----------------------------------------------------------------------------
 
-def save_image_grid(img, fname, drange, grid_size):
-    lo, hi = drange
-    img = np.asarray(img, dtype=np.float32)
-    img = (img - lo) * (255 / (hi - lo))
-    img = np.rint(img).clip(0, 255).astype(np.uint8)
-
-    gw, gh = grid_size
-    _N, C, H, W = img.shape
-    img = img.reshape([gh, gw, C, H, W])
-    img = img.transpose(0, 3, 1, 4, 2)
-    img = img.reshape([gh * H, gw * W, C])
-
-    assert C in [1, 3]
-    if C == 1:
-        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
-    if C == 3:
-        PIL.Image.fromarray(img, 'RGB').save(fname)
+def save_image_grid(img_infos, fname_base, drange, grid_size):
+    """Save one image grid per modality, named '<stem>-<key><ext>' after fname_base.
+
+    img_infos maps 'rgb'/'orig' (scaled with drange) and 'depth'/'acc' (scaled with
+    [0, 1]) to NCHW arrays; other keys are skipped and a bare array is saved as 'rgb'.
+    grid_size is (gw, gh) and must satisfy gw * gh == N.
+    """
+    if not isinstance(img_infos, dict):
+        img_infos = {'rgb': img_infos}
+    # Split off only the final extension: str.replace would rewrite every '.png' in the path.
+    stem, ext = os.path.splitext(fname_base)
+    gw, gh = grid_size
+    for key, img in img_infos.items():
+        if key not in ['rgb', 'depth', 'acc', 'orig']:
+            continue
+        # Colour images use the caller's range; depth and acc are scaled from [0, 1].
+        lo, hi = drange if key in ['rgb', 'orig'] else (0, 1)
+        img = np.asarray(img, dtype=np.float32)
+        img = (img - lo) * (255 / (hi - lo))
+        img = np.rint(img).clip(0, 255).astype(np.uint8)
+
+        # Tile the batch row by row: (N, C, H, W) -> (gh * H, gw * W, C).
+        _N, C, H, W = img.shape
+        img = img.reshape([gh, gw, C, H, W])
+        img = img.transpose(0, 3, 1, 4, 2)
+        img = img.reshape([gh * H, gw * W, C])
+
+        assert C in [1, 3]
+        tile, mode = (img[:, :, 0], 'L') if C == 1 else (img, 'RGB')
+        PIL.Image.fromarray(tile, mode).save('%s-%s%s' % (stem, key, ext))
 
 #----------------------------------------------------------------------------
 
@@ -120,7 +151,13 @@
     cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
     abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
     progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
+    training_mode           = None,     # Which training mode to use: 'layout', 'upsampler' or 'sky'.
+    wrapper_kwargs          = {},       # Model wrapper arguments (accessed by attribute in 'layout' mode).
+    decoder_kwargs          = {},       # Additional arguments for the layout decoder.
+    torgb_kwargs            = {},       # Additional arguments for the layout torgb layer.
+
 ):
+
     # Initialize.
     start_time = time.time()
     device = torch.device('cuda', rank)
@@ -132,6 +169,9 @@
     conv2d_gradfix.enabled = True                       # Improves training speed.
     grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
 
+    # Added to prevent the data loader's pin_memory from loading onto device 0 in every process.
+    torch.cuda.set_device(device)
+
     # Load training set.
     if rank == 0:
         print('Loading training set...')
@@ -148,9 +188,60 @@
     # Construct networks.
     if rank == 0:
         print('Constructing networks...')
-    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
-    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
-    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+    # common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
+    # G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+    # D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+    if training_mode == 'layout':
+        from models.layout import model_layout
+        from external.gsn.models.model_utils import TrajectorySampler
+        voxel_size = wrapper_kwargs.voxel_size
+        voxel_res = wrapper_kwargs.voxel_res
+        image_infos,  _ = training_set[0]
+        trajectory_sampler = TrajectorySampler(real_Rts=torch.from_numpy(training_set.Rt).float(), mode='sample').to(device)
+        # decoder part common kwargs
+        G_kwargs.c_dim = training_set.label_dim
+        # discriminator common kwargs
+        D_kwargs.c_dim = training_set.label_dim
+        D_kwargs.img_resolution=training_set.resolution
+        D_kwargs.img_channels=training_set.num_channels+int(loss_kwargs.loss_layout_kwargs.concat_acc)+int(loss_kwargs.loss_layout_kwargs.concat_depth)
+        G = model_layout.ModelLayout(G_kwargs, decoder_kwargs, torgb_kwargs, **wrapper_kwargs).train().requires_grad_(False).to(device)
+        G.set_trajectory_sampler(trajectory_sampler=trajectory_sampler)
+        if not loss_kwargs.loss_layout_kwargs.use_wrapped_discriminator:
+            D = dnnlib.util.construct_class_by_name(**D_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+        else:
+            # use two discriminators: one on RGBD, and one on sky mask
+            assert loss_kwargs.loss_layout_kwargs.concat_acc
+            D_kwargs.img_channels=training_set.num_channels+int(loss_kwargs.loss_layout_kwargs.concat_depth)
+            D_img = dnnlib.util.construct_class_by_name(**D_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+            D_kwargs_acc = copy.deepcopy(D_kwargs)
+            D_kwargs_acc.img_channels=1
+            D_acc = dnnlib.util.construct_class_by_name(**D_kwargs_acc).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+            from models.misc.networks import WrappedDiscriminator
+            wrappedD = WrappedDiscriminator(D_img, D_acc).train().requires_grad_(False).to(device)
+            D = wrappedD
+    elif training_mode == 'upsampler':
+        common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution)
+        G_kwargs.img_channels = training_set.num_channels + G_kwargs.num_additional_feature_channels
+        # G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+        D_img_channels = G_kwargs.img_channels
+        D_kwargs.img_channels = G_kwargs.img_channels
+        D_kwargs.recon = False # no reconstruction part
+        D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+        from models.layout.model_terrain import ModelTerrain
+        wrapper = ModelTerrain(**G_kwargs, **common_kwargs, **wrapper_kwargs)
+        G = wrapper.train().requires_grad_(False).to(device)
+    elif training_mode == 'sky':
+        common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
+        G_kwargs.enc_dim = 512
+        G_kwargs.training_mode = 'global-360'
+        G_kwargs.fov = int(training_set.fov_mean)
+        G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+        D_kwargs.recon = False # no reconstruction part
+        D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+        from models.sky.model_sky import ModelSky
+        wrapper = ModelSky(G)
+        G = wrapper.train().requires_grad_(False).to(device)
+
     G_ema = copy.deepcopy(G).eval()
 
     # Resume from existing pickle.
@@ -158,15 +249,33 @@
         print(f'Resuming from "{resume_pkl}"')
         with dnnlib.util.open_url(resume_pkl) as f:
             resume_data = legacy.load_network_pkl(f)
-        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
-            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
+        if training_mode == 'layout' and loss_kwargs.loss_layout_kwargs.use_wrapped_discriminator:
+            for name, module in [('G', G), ('D', D.D_img), ('G_ema', G_ema)]:
+                # assume that resume is from a checkpoint without wrapped discriminator
+                misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
+        else:
+            for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
+                misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
 
     # Print network summary tables.
     if rank == 0:
         z = torch.empty([batch_gpu, G.z_dim], device=device)
         c = torch.empty([batch_gpu, G.c_dim], device=device)
-        img = misc.print_module_summary(G, [z, c])
-        misc.print_module_summary(D, [img, c])
+        if training_mode == 'layout': 
+            camera_params = camera_util.get_full_image_parameters(
+                G, G.layout_decoder_kwargs.nerf_out_res,
+                batch_size=batch_gpu, device=z.device, Rt=None)
+            img, _ = misc.print_module_summary(G, [z, c, camera_params])
+            img = torch.empty([batch_gpu, D.img_channels, D.img_resolution, D.img_resolution], device=device)
+            misc.print_module_summary(D, [img, c])
+        elif training_mode == 'upsampler':
+            img, thumb = misc.print_module_summary(G, [z, c])
+            img_for_D = torch.empty([batch_gpu, D.img_channels, D.img_resolution, D.img_resolution], device=device)
+            misc.print_module_summary(D, [img_for_D, c])
+        elif training_mode == 'sky':
+            ref_img = torch.empty([batch_gpu, G.img_channels, G.img_resolution, G.img_resolution], device=device)
+            img = misc.print_module_summary(G, [z, c, ref_img, ref_img])
+            misc.print_module_summary(D, [img, c])
 
     # Setup augmentation.
     if rank == 0:
@@ -183,17 +292,24 @@
     if rank == 0:
         print(f'Distributing across {num_gpus} GPUs...')
     for module in [G, D, G_ema, augment_pipe]:
-        if module is not None:
+        if module is not None and num_gpus > 1:
             for param in misc.params_and_buffers(module):
-                if param.numel() > 0 and num_gpus > 1:
-                    torch.distributed.broadcast(param, src=0)
+                torch.distributed.broadcast(param, src=0)
 
     # Setup training phases.
     if rank == 0:
         print('Setting up training phases...')
+    # The loss needs the training mode in addition to its per-mode loss kwargs.
+    loss_kwargs.training_mode = training_mode
     loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss
     phases = []
     for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
+        if training_mode == 'sky' and name == 'G':  # NOTE(review): optim_params is never passed to the optimizer built just below (it still gets module.parameters()) - confirm intent.
+            optim_params = module.G.parameters()
+        elif training_mode == 'upsampler' and name == 'G':
+            optim_params = module.upsampler.parameters()
+        else:
+            optim_params = module.parameters()
         if reg_interval is None:
             opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
             phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
@@ -222,8 +338,49 @@
         save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
         grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
         grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
-        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
-        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+        # custom image saving for each training mode
+        if training_mode == 'layout':
+            images = defaultdict(list)
+            for z, c in zip(grid_z, grid_c):
+                camera_params = camera_util.get_full_image_parameters(
+                    G, G.layout_decoder_kwargs.nerf_out_res,
+                    batch_size=z.shape[0], device=z.device, Rt=None)
+                _, infos = G_ema(z=z, c=c, camera_params=camera_params, noise_mode='const')
+                for k, v in infos.items():
+                    if k in ['rgb', 'depth', 'acc']:
+                        images[k].append(v.cpu())
+            images = {k: torch.cat(v).numpy() for k, v in images.items()}
+            save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+        elif training_mode == 'upsampler':
+            images_fake, images_fake_thumb = [], []
+            for z, c in zip(grid_z, grid_c):
+                im_fake, im_fake_thumb = G_ema(z=z, c=c, noise_mode='const')
+                images_fake.append(im_fake.cpu())
+                images_fake_thumb.append(im_fake_thumb.cpu())
+            images_fake = torch.cat(images_fake).numpy()
+            images_fake_thumb = torch.cat(images_fake_thumb).numpy()
+            images_fake_infos = {'rgb': images_fake[:, :3]}
+            images_fake_thumb_infos = {'rgb': images_fake_thumb[:, :3]}
+            if images_fake.shape[1] > 3:
+                images_fake_infos['depth'] = images_fake[:, 3:4]
+                images_fake_thumb_infos['depth'] = images_fake_thumb[:, 3:4]
+            if images_fake.shape[1] > 4:
+                images_fake_infos['acc'] = images_fake[:, 4:5]
+                images_fake_thumb_infos['acc'] = images_fake_thumb[:, 4:5]
+            save_image_grid(images_fake_infos, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+            save_image_grid(images_fake_thumb_infos, os.path.join(run_dir, 'fakes_init_thumb.png'), drange=[-1,1], grid_size=grid_size)
+        elif training_mode == 'sky':
+            images_masked = torch.from_numpy((images['rgb'] / 127.5) - 1).to(device).split(batch_gpu)
+            images_acc = torch.from_numpy(images['acc']).to(device).split(batch_gpu)
+            images_fake = torch.cat([G_ema(z=z, c=c, img=im_masked, acc=im_acc, noise_mode='const').cpu()
+                                     for z, c, im_masked, im_acc in zip(grid_z, grid_c, images_masked, images_acc)]).numpy()
+            save_image_grid(images_fake, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
+            images_fake_nomask = torch.cat([G_ema(z=z, c=c, img=im_masked,
+                                                  acc=im_acc, multiply=False, noise_mode='const').cpu()
+                                     for z, c, im_masked, im_acc in zip(grid_z, grid_c, images_masked, images_acc)]).numpy()
+            save_image_grid(images_fake_nomask, os.path.join(run_dir, 'fakes_init_nomask.png'), drange=[-1,1], grid_size=grid_size)
+        # images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
+        # save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
 
     # Initialize logs.
     if rank == 0:
@@ -256,8 +413,30 @@
 
         # Fetch training data.
         with torch.autograd.profiler.record_function('data_fetch'):
-            phase_real_img, phase_real_c = next(training_set_iterator)
-            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
+            phase_real_infos, phase_real_c = next(training_set_iterator)
+            # masked image without sky
+            phase_real_img = (phase_real_infos['rgb'].to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
+            # full image with sky
+            phase_real_orig = (phase_real_infos['orig'].to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
+            # NOTE: phase_real_depth actually holds inverse depth (disparity) when use_disp=True.
+            phase_real_depth = phase_real_infos['depth'].to(device).to(torch.float32).split(batch_gpu)
+            phase_real_acc = phase_real_infos['acc'].to(device).to(torch.float32).split(batch_gpu)
+            # zeroes the sky out, s.t. the sky maps exactly to zero for rgb and disparity
+            phase_real_img = [img * acc for img, acc in zip(phase_real_img, phase_real_acc)]
+            phase_real_depth = [depth * acc for depth, acc in zip(phase_real_depth, phase_real_acc)]
+            phase_real_K = phase_real_infos['K'].to(device).to(torch.float32).split(batch_gpu)
+            phase_real_Rt = phase_real_infos['Rt'].to(device).to(torch.float32).split(batch_gpu)
+            phase_real_size = phase_real_infos['global_size'].to(device).split(batch_gpu)
+            phase_real_fov = phase_real_infos['fov'].to(device).split(batch_gpu)
+            phase_real_infos = [dict(rgb=rgb, depth=depth, acc=acc, orig=orig,
+                                     camera_params=dict(K=K, Rt=Rt,
+                                     global_size=size, fov=fov))
+                                for (rgb, depth, acc, orig, K, Rt, size, fov)
+                                in zip(phase_real_img, phase_real_depth,
+                                       phase_real_acc, phase_real_orig,
+                                       phase_real_K, phase_real_Rt,
+                                       phase_real_size, phase_real_fov)]
+
             phase_real_c = phase_real_c.to(device).split(batch_gpu)
             all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
             all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
@@ -275,13 +454,13 @@
             # Accumulate gradients.
             phase.opt.zero_grad(set_to_none=True)
             phase.module.requires_grad_(True)
-            for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c):
-                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
+            for real_img_infos, real_c, gen_z, gen_c in zip(phase_real_infos, phase_real_c, phase_gen_z, phase_gen_c):
+                loss.accumulate_gradients(phase=phase.name, real_img_infos=real_img_infos, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
             phase.module.requires_grad_(False)
 
             # Update weights.
             with torch.autograd.profiler.record_function(phase.name + '_opt'):
-                params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None]
+                params = [param for param in phase.module.parameters() if param.grad is not None]
                 if len(params) > 0:
                     flat = torch.cat([param.grad.flatten() for param in params])
                     if num_gpus > 1:
@@ -351,21 +530,66 @@
 
         # Save image snapshot.
         if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
-            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
-            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
+            # images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
+            # save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
+            # custom image saving for each training mode
+            if training_mode == 'layout':
+                images = defaultdict(list)
+                for z, c in zip(grid_z, grid_c):
+                    camera_params = camera_util.get_full_image_parameters(
+                        G, G.layout_decoder_kwargs.nerf_out_res,
+                        batch_size=z.shape[0], device=z.device, Rt=None)
+                    _, infos = G_ema(z=z, c=c, camera_params=camera_params, noise_mode='const')
+                    for k, v in infos.items():
+                        if k in ['rgb', 'depth', 'acc']:
+                            images[k].append(v.cpu())
+                images = {k: torch.cat(v).numpy() for k, v in images.items()}
+                save_image_grid(images,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'),drange=[-1,1], grid_size=grid_size)
+            elif training_mode == 'upsampler':
+                images_fake, images_fake_thumb = [], []
+                for z, c in zip(grid_z, grid_c):
+                    im_fake, im_fake_thumb = G_ema(z=z, c=c, noise_mode='const')
+                    images_fake.append(im_fake.cpu())
+                    images_fake_thumb.append(im_fake_thumb.cpu())
+                images_fake = torch.cat(images_fake).numpy()
+                images_fake_thumb = torch.cat(images_fake_thumb).numpy()
+                images_fake_infos = {'rgb': images_fake[:, :3]}
+                images_fake_thumb_infos = {'rgb': images_fake_thumb[:, :3]}
+                if images_fake.shape[1] > 3:
+                    images_fake_infos['depth'] = images_fake[:, 3:4]
+                    images_fake_thumb_infos['depth'] = images_fake_thumb[:, 3:4]
+                if images_fake.shape[1] > 4:
+                    images_fake_infos['acc'] = images_fake[:, 4:5]
+                    images_fake_thumb_infos['acc'] = images_fake_thumb[:, 4:5]
+                save_image_grid(images_fake_infos,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'),drange=[-1,1], grid_size=grid_size)
+                save_image_grid(images_fake_thumb_infos,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_thumb.png'),drange=[-1,1], grid_size=grid_size)
+            elif training_mode == 'sky':
+                # Condition G_ema on the real grid images: rgb rescaled via /127.5 - 1, plus the matching acc mask.
+                images_masked = torch.from_numpy((images['rgb'] / 127.5) - 1).to(device).split(batch_gpu)
+                images_acc = torch.from_numpy(images['acc']).to(device).split(batch_gpu)
+                grid_inputs = list(zip(grid_z, grid_c, images_masked, images_acc))
+                images_fake = torch.cat([G_ema(z=z, c=c, img=im, acc=acc, noise_mode='const').cpu() for z, c, im, acc in grid_inputs]).numpy()
+                save_image_grid(images_fake,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'),drange=[-1,1], grid_size=grid_size)
+                # Second pass with multiply=False; this grid (not images_fake) is what the '_nomask' file must contain.
+                images_fake_nomask = torch.cat([G_ema(z=z, c=c, img=im, acc=acc, multiply=False, noise_mode='const').cpu() for z, c, im, acc in grid_inputs]).numpy()
+                save_image_grid(images_fake_nomask, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_nomask.png'), drange=[-1,1], grid_size=grid_size)
 
         # Save network snapshot.
         snapshot_pkl = None
         snapshot_data = None
         if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
-            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
-            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
-                if module is not None:
+            snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs))
+            for key, value in snapshot_data.items():
+                if isinstance(value, torch.nn.Module):
+                    value = copy.deepcopy(value).eval().requires_grad_(False)
+                    if training_mode == 'sky' and 'G' in key:
+                        value.encoder = None
                     if num_gpus > 1:
-                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)')
-                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
-                snapshot_data[name] = module
-                del module # conserve memory
+                        misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)')
+                        for param in misc.params_and_buffers(value):
+                            torch.distributed.broadcast(param, src=0)
+                    snapshot_data[key] = value.cpu()
+                del value # conserve memory
             snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
             if rank == 0:
                 with open(snapshot_pkl, 'wb') as f:
@@ -376,8 +600,10 @@
             if rank == 0:
                 print('Evaluating metrics...')
             for metric in metrics:
-                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
-                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
+                result_dict = metric_main.calc_metric(
+                    metric=metric, G=G_ema, # snapshot_data['G_ema'],
+                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
+                    rank=rank, device=device, training_mode=training_mode)
                 if rank == 0:
                     metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                 stats_metrics.update(result_dict.results)
